//
// Copyright (c) 2008 University of Washington
//
// SPDX-License-Identifier: GPL-2.0-only
//

#include "ipv4-global-routing.h"

#include "global-route-manager.h"
#include "ipv4-route.h"
#include "ipv4-routing-table-entry.h"

#include "ns3/boolean.h"
#include "ns3/log.h"
#include "ns3/names.h"
#include "ns3/net-device.h"
#include "ns3/node.h"
#include "ns3/object.h"
#include "ns3/packet.h"
#include "ns3/simulator.h"

#include <iomanip>
#include <vector>

namespace ns3
{

NS_LOG_COMPONENT_DEFINE("Ipv4GlobalRouting");

NS_OBJECT_ENSURE_REGISTERED(Ipv4GlobalRouting);

TypeId
Ipv4GlobalRouting::GetTypeId()
{
    static TypeId tid =
        TypeId("ns3::Ipv4GlobalRouting")
            .SetParent<Object>()
            .SetGroupName("Internet")
            .AddAttribute("RandomEcmpRouting",
                          "Set to true if packets are randomly routed among ECMP; set to false for "
                          "using only one route consistently",
                          BooleanValue(false),
                          MakeBooleanAccessor(&Ipv4GlobalRouting::m_randomEcmpRouting),
                          MakeBooleanChecker())
            .AddAttribute("RespondToInterfaceEvents",
                          "Set to true if you want to dynamically recompute the global routes upon "
                          "Interface notification events (up/down, or add/remove address)",
                          BooleanValue(false),
                          MakeBooleanAccessor(&Ipv4GlobalRouting::m_respondToInterfaceEvents),
                          MakeBooleanChecker());
    return tid;
}

Ipv4GlobalRouting::Ipv4GlobalRouting()
    : m_randomEcmpRouting(false),
      m_respondToInterfaceEvents(false)
{
    NS_LOG_FUNCTION(this);

    m_rand = CreateObject<UniformRandomVariable>();
}

Ipv4GlobalRouting::~Ipv4GlobalRouting()
{
    NS_LOG_FUNCTION(this);
}

void
Ipv4GlobalRouting::AddHostRouteTo(Ipv4Address dest, Ipv4Address nextHop, uint32_t interface)
{
    NS_LOG_FUNCTION(this << dest << nextHop << interface);
    auto route = new Ipv4RoutingTableEntry();
    *route = Ipv4RoutingTableEntry::CreateHostRouteTo(dest, nextHop, interface);
    for (auto routePointer : m_hostRoutes)
    {
        if (*routePointer == *route)
        {
            NS_LOG_LOGIC("Route already exists");
            delete route;
            return;
        }
    }
    m_hostRoutes.push_back(route);
}

void
Ipv4GlobalRouting::AddHostRouteTo(Ipv4Address dest, uint32_t interface)
{
    NS_LOG_FUNCTION(this << dest << interface);
    auto route = new Ipv4RoutingTableEntry();
    *route = Ipv4RoutingTableEntry::CreateHostRouteTo(dest, interface);
    for (auto routePointer : m_hostRoutes)
    {
        if (*routePointer == *route)
        {
            NS_LOG_LOGIC("Route already exists");
            delete route;
            return;
        }
    }
    m_hostRoutes.push_back(route);
}

void
Ipv4GlobalRouting::AddNetworkRouteTo(Ipv4Address network,
                                     Ipv4Mask networkMask,
                                     Ipv4Address nextHop,
                                     uint32_t interface)
{
    NS_LOG_FUNCTION(this << network << networkMask << nextHop << interface);
    auto route = new Ipv4RoutingTableEntry();
    *route = Ipv4RoutingTableEntry::CreateNetworkRouteTo(network, networkMask, nextHop, interface);
    for (auto routePointer : m_networkRoutes)
    {
        if (*routePointer == *route)
        {
            NS_LOG_LOGIC("Route already exists");
            delete route;
            return;
        }
    }
    m_networkRoutes.push_back(route);
}

void
Ipv4GlobalRouting::AddNetworkRouteTo(Ipv4Address network, Ipv4Mask networkMask, uint32_t interface)
{
    NS_LOG_FUNCTION(this << network << networkMask << interface);
    auto route = new Ipv4RoutingTableEntry();
    *route = Ipv4RoutingTableEntry::CreateNetworkRouteTo(network, networkMask, interface);
    for (auto routePointer : m_networkRoutes)
    {
        if (*routePointer == *route)
        {
            NS_LOG_LOGIC("Route already exists");
            delete route;
            return;
        }
    }
    m_networkRoutes.push_back(route);
}

void
Ipv4GlobalRouting::AddASExternalRouteTo(Ipv4Address network,
                                        Ipv4Mask networkMask,
                                        Ipv4Address nextHop,
                                        uint32_t interface)
{
    NS_LOG_FUNCTION(this << network << networkMask << nextHop << interface);
    auto route = new Ipv4RoutingTableEntry();
    *route = Ipv4RoutingTableEntry::CreateNetworkRouteTo(network, networkMask, nextHop, interface);
    for (auto routePointer : m_ASexternalRoutes)
    {
        if (*routePointer == *route)
        {
            NS_LOG_LOGIC("Route already exists");
            delete route;
            return;
        }
    }
    m_ASexternalRoutes.push_back(route);
}

Ptr<Ipv4Route>
Ipv4GlobalRouting::LookupGlobal(Ipv4Address dest, Ptr<NetDevice> oif)
{
    NS_LOG_FUNCTION(this << dest << oif);
    NS_LOG_LOGIC("Looking for route for destination " << dest);
    Ptr<Ipv4Route> rtentry = nullptr;
    // store all available routes that bring packets to their destination
    typedef std::vector<Ipv4RoutingTableEntry*> RouteVec_t;
    RouteVec_t allRoutes;

    NS_LOG_LOGIC("Number of m_hostRoutes = " << m_hostRoutes.size());
    for (auto i = m_hostRoutes.begin(); i != m_hostRoutes.end(); i++)
    {
        NS_ASSERT((*i)->IsHost());
        if ((*i)->GetDest() == dest)
        {
            if (oif)
            {
                if (oif != m_ipv4->GetNetDevice((*i)->GetInterface()))
                {
                    NS_LOG_LOGIC("Not on requested interface, skipping");
                    continue;
                }
            }
            allRoutes.push_back(*i);
            NS_LOG_LOGIC(allRoutes.size() << "Found global host route" << *i);
        }
    }
    if (allRoutes.empty()) // if no host route is found
    {
        NS_LOG_LOGIC("Number of m_networkRoutes" << m_networkRoutes.size());
        // store the length of the longest mask.
        uint16_t longest_mask = 0;
        for (auto j = m_networkRoutes.begin(); j != m_networkRoutes.end(); j++)
        {
            Ipv4Mask mask = (*j)->GetDestNetworkMask();
            uint16_t masklen = mask.GetPrefixLength();

            Ipv4Address entry = (*j)->GetDestNetwork();
            if (mask.IsMatch(dest, entry))
            {
                if (oif)
                {
                    if (oif != m_ipv4->GetNetDevice((*j)->GetInterface()))
                    {
                        NS_LOG_LOGIC("Not on requested interface, skipping");
                        continue;
                    }
                }
                NS_LOG_LOGIC(allRoutes.size() << "Found global network route" << *j);
                if (masklen < longest_mask) // Not interested if got shorter mask
                {
                    NS_LOG_LOGIC("Previous match longer, skipping");
                    continue;
                }
                else if (masklen == longest_mask)
                {
                    NS_LOG_LOGIC("Equal mask length, adding this to the list");
                    allRoutes.push_back(*j);
                }
                else
                {
                    NS_LOG_LOGIC("Longer mask length found, clearing the list and adding");
                    allRoutes.clear();
                    allRoutes.push_back(*j);
                }
            }
        }
    }
    if (allRoutes.empty()) // consider external if no host/network found
    {
        for (auto k = m_ASexternalRoutes.begin(); k != m_ASexternalRoutes.end(); k++)
        {
            Ipv4Mask mask = (*k)->GetDestNetworkMask();
            Ipv4Address entry = (*k)->GetDestNetwork();
            if (mask.IsMatch(dest, entry))
            {
                NS_LOG_LOGIC("Found external route" << *k);
                if (oif)
                {
                    if (oif != m_ipv4->GetNetDevice((*k)->GetInterface()))
                    {
                        NS_LOG_LOGIC("Not on requested interface, skipping");
                        continue;
                    }
                }
                allRoutes.push_back(*k);
                break;
            }
        }
    }
    if (!allRoutes.empty()) // if route(s) is found
    {
        // pick up one of the routes uniformly at random if random
        // ECMP routing is enabled, or always select the first route
        // consistently if random ECMP routing is disabled
        uint32_t selectIndex;
        if (m_randomEcmpRouting)
        {
            selectIndex = m_rand->GetInteger(0, allRoutes.size() - 1);
        }
        else
        {
            selectIndex = 0;
        }
        Ipv4RoutingTableEntry* route = allRoutes.at(selectIndex);
        // create a Ipv4Route object from the selected routing table entry
        rtentry = Create<Ipv4Route>();
        rtentry->SetDestination(route->GetDest());
        /// @todo handle multi-address case
        rtentry->SetSource(m_ipv4->GetAddress(route->GetInterface(), 0).GetLocal());
        rtentry->SetGateway(route->GetGateway());
        uint32_t interfaceIdx = route->GetInterface();
        rtentry->SetOutputDevice(m_ipv4->GetNetDevice(interfaceIdx));
        return rtentry;
    }
    else
    {
        return nullptr;
    }
}

uint32_t
Ipv4GlobalRouting::GetNRoutes() const
{
    NS_LOG_FUNCTION(this);
    uint32_t n = 0;
    n += m_hostRoutes.size();
    n += m_networkRoutes.size();
    n += m_ASexternalRoutes.size();
    return n;
}

Ipv4RoutingTableEntry*
Ipv4GlobalRouting::GetRoute(uint32_t index) const
{
    NS_LOG_FUNCTION(this << index);
    if (index < m_hostRoutes.size())
    {
        uint32_t tmp = 0;
        for (auto i = m_hostRoutes.begin(); i != m_hostRoutes.end(); i++)
        {
            if (tmp == index)
            {
                return *i;
            }
            tmp++;
        }
    }
    index -= m_hostRoutes.size();
    uint32_t tmp = 0;
    if (index < m_networkRoutes.size())
    {
        for (auto j = m_networkRoutes.begin(); j != m_networkRoutes.end(); j++)
        {
            if (tmp == index)
            {
                return *j;
            }
            tmp++;
        }
    }
    index -= m_networkRoutes.size();
    tmp = 0;
    for (auto k = m_ASexternalRoutes.begin(); k != m_ASexternalRoutes.end(); k++)
    {
        if (tmp == index)
        {
            return *k;
        }
        tmp++;
    }
    NS_ASSERT(false);
    // quiet compiler.
    return nullptr;
}

void
Ipv4GlobalRouting::RemoveRoute(uint32_t index)
{
    NS_LOG_FUNCTION(this << index);
    if (index < m_hostRoutes.size())
    {
        uint32_t tmp = 0;
        for (auto i = m_hostRoutes.begin(); i != m_hostRoutes.end(); i++)
        {
            if (tmp == index)
            {
                NS_LOG_LOGIC("Removing route " << index << "; size = " << m_hostRoutes.size());
                delete *i;
                m_hostRoutes.erase(i);
                NS_LOG_LOGIC("Done removing host route "
                             << index << "; host route remaining size = " << m_hostRoutes.size());
                return;
            }
            tmp++;
        }
    }
    index -= m_hostRoutes.size();
    uint32_t tmp = 0;
    for (auto j = m_networkRoutes.begin(); j != m_networkRoutes.end(); j++)
    {
        if (tmp == index)
        {
            NS_LOG_LOGIC("Removing route " << index << "; size = " << m_networkRoutes.size());
            delete *j;
            m_networkRoutes.erase(j);
            NS_LOG_LOGIC("Done removing network route "
                         << index << "; network route remaining size = " << m_networkRoutes.size());
            return;
        }
        tmp++;
    }
    index -= m_networkRoutes.size();
    tmp = 0;
    for (auto k = m_ASexternalRoutes.begin(); k != m_ASexternalRoutes.end(); k++)
    {
        if (tmp == index)
        {
            NS_LOG_LOGIC("Removing route " << index << "; size = " << m_ASexternalRoutes.size());
            delete *k;
            m_ASexternalRoutes.erase(k);
            NS_LOG_LOGIC("Done removing network route "
                         << index << "; network route remaining size = " << m_networkRoutes.size());
            return;
        }
        tmp++;
    }
    NS_ASSERT(false);
}

int64_t
Ipv4GlobalRouting::AssignStreams(int64_t stream)
{
    NS_LOG_FUNCTION(this << stream);
    m_rand->SetStream(stream);
    return 1;
}

void
Ipv4GlobalRouting::DoDispose()
{
    NS_LOG_FUNCTION(this);
    for (auto i = m_hostRoutes.begin(); i != m_hostRoutes.end(); i = m_hostRoutes.erase(i))
    {
        delete (*i);
    }
    for (auto j = m_networkRoutes.begin(); j != m_networkRoutes.end(); j = m_networkRoutes.erase(j))
    {
        delete (*j);
    }
    for (auto l = m_ASexternalRoutes.begin(); l != m_ASexternalRoutes.end();
         l = m_ASexternalRoutes.erase(l))
    {
        delete (*l);
    }

    Ipv4RoutingProtocol::DoDispose();
}

// Formatted like output of "route -n" command
void
Ipv4GlobalRouting::PrintRoutingTable(Ptr<OutputStreamWrapper> stream, Time::Unit unit) const
{
    NS_LOG_FUNCTION(this << stream);
    std::ostream* os = stream->GetStream();
    // Copy the current ostream state
    std::ios oldState(nullptr);
    oldState.copyfmt(*os);

    *os << std::resetiosflags(std::ios::adjustfield) << std::setiosflags(std::ios::left);

    *os << "Node: " << m_ipv4->GetObject<Node>()->GetId() << ", Time: " << Now().As(unit)
        << ", Local time: " << m_ipv4->GetObject<Node>()->GetLocalTime().As(unit)
        << ", Ipv4GlobalRouting table" << std::endl;

    if (GetNRoutes() > 0)
    {
        *os << "Destination     Gateway         Genmask         Flags Metric Ref    Use Iface"
            << std::endl;
        for (uint32_t j = 0; j < GetNRoutes(); j++)
        {
            std::ostringstream dest;
            std::ostringstream gw;
            std::ostringstream mask;
            std::ostringstream flags;
            Ipv4RoutingTableEntry route = GetRoute(j);
            dest << route.GetDest();
            *os << std::setw(16) << dest.str();
            gw << route.GetGateway();
            *os << std::setw(16) << gw.str();
            mask << route.GetDestNetworkMask();
            *os << std::setw(16) << mask.str();
            flags << "U";
            if (route.IsHost())
            {
                flags << "H";
            }
            else if (route.IsGateway())
            {
                flags << "G";
            }
            *os << std::setw(6) << flags.str();
            // Metric not implemented
            *os << "-"
                << "      ";
            // Ref ct not implemented
            *os << "-"
                << "      ";
            // Use not implemented
            *os << "-"
                << "   ";
            if (!Names::FindName(m_ipv4->GetNetDevice(route.GetInterface())).empty())
            {
                *os << Names::FindName(m_ipv4->GetNetDevice(route.GetInterface()));
            }
            else
            {
                *os << route.GetInterface();
            }
            *os << std::endl;
        }
    }
    *os << std::endl;
    // Restore the previous ostream state
    (*os).copyfmt(oldState);
}

Ptr<Ipv4Route>
Ipv4GlobalRouting::RouteOutput(Ptr<Packet> p,
                               const Ipv4Header& header,
                               Ptr<NetDevice> oif,
                               Socket::SocketErrno& sockerr)
{
    NS_LOG_FUNCTION(this << p << &header << oif << &sockerr);
    //
    // First, see if this is a multicast packet we have a route for.  If we
    // have a route, then send the packet down each of the specified interfaces.
    //
    if (header.GetDestination().IsMulticast())
    {
        NS_LOG_LOGIC("Multicast destination-- returning false");
        return nullptr; // Let other routing protocols try to handle this
    }
    //
    // See if this is a unicast packet we have a route for.
    //
    NS_LOG_LOGIC("Unicast destination- looking up");
    Ptr<Ipv4Route> rtentry = LookupGlobal(header.GetDestination(), oif);
    if (rtentry)
    {
        sockerr = Socket::ERROR_NOTERROR;
    }
    else
    {
        sockerr = Socket::ERROR_NOROUTETOHOST;
    }
    return rtentry;
}

bool
Ipv4GlobalRouting::RouteInput(Ptr<const Packet> p,
                              const Ipv4Header& header,
                              Ptr<const NetDevice> idev,
                              const UnicastForwardCallback& ucb,
                              const MulticastForwardCallback& mcb,
                              const LocalDeliverCallback& lcb,
                              const ErrorCallback& ecb)
{
    NS_LOG_FUNCTION(this << p << header << header.GetSource() << header.GetDestination() << idev
                         << &lcb << &ecb);
    // Check if input device supports IP
    NS_ASSERT(m_ipv4->GetInterfaceForDevice(idev) >= 0);
    uint32_t iif = m_ipv4->GetInterfaceForDevice(idev);

    if (m_ipv4->IsDestinationAddress(header.GetDestination(), iif))
    {
        if (!lcb.IsNull())
        {
            NS_LOG_LOGIC("Local delivery to " << header.GetDestination());
            lcb(p, header, iif);
            return true;
        }
        else
        {
            // The local delivery callback is null.  This may be a multicast
            // or broadcast packet, so return false so that another
            // multicast routing protocol can handle it.  It should be possible
            // to extend this to explicitly check whether it is a unicast
            // packet, and invoke the error callback if so
            return false;
        }
    }

    // Check if input device supports IP forwarding
    if (!m_ipv4->IsForwarding(iif))
    {
        NS_LOG_LOGIC("Forwarding disabled for this interface");
        ecb(p, header, Socket::ERROR_NOROUTETOHOST);
        return true;
    }
    // Next, try to find a route
    NS_LOG_LOGIC("Unicast destination- looking up global route");
    Ptr<Ipv4Route> rtentry = LookupGlobal(header.GetDestination());
    if (rtentry)
    {
        NS_LOG_LOGIC("Found unicast destination- calling unicast callback");
        ucb(rtentry, p, header);
        return true;
    }
    else
    {
        NS_LOG_LOGIC("Did not find unicast destination- returning false");
        return false; // Let other routing protocols try to handle this
                      // route request.
    }
}

void
Ipv4GlobalRouting::NotifyInterfaceUp(uint32_t i)
{
    NS_LOG_FUNCTION(this << i);
    if (m_respondToInterfaceEvents && Simulator::Now().GetSeconds() > 0) // avoid startup events
    {
        GlobalRouteManager::DeleteGlobalRoutes();
        GlobalRouteManager::BuildGlobalRoutingDatabase();
        GlobalRouteManager::InitializeRoutes();
    }
}

void
Ipv4GlobalRouting::NotifyInterfaceDown(uint32_t i)
{
    NS_LOG_FUNCTION(this << i);
    if (m_respondToInterfaceEvents && Simulator::Now().GetSeconds() > 0) // avoid startup events
    {
        GlobalRouteManager::DeleteGlobalRoutes();
        GlobalRouteManager::BuildGlobalRoutingDatabase();
        GlobalRouteManager::InitializeRoutes();
    }
}

void
Ipv4GlobalRouting::NotifyAddAddress(uint32_t interface, Ipv4InterfaceAddress address)
{
    NS_LOG_FUNCTION(this << interface << address);
    if (m_respondToInterfaceEvents && Simulator::Now().GetSeconds() > 0) // avoid startup events
    {
        GlobalRouteManager::DeleteGlobalRoutes();
        GlobalRouteManager::BuildGlobalRoutingDatabase();
        GlobalRouteManager::InitializeRoutes();
    }
}

void
Ipv4GlobalRouting::NotifyRemoveAddress(uint32_t interface, Ipv4InterfaceAddress address)
{
    NS_LOG_FUNCTION(this << interface << address);
    if (m_respondToInterfaceEvents && Simulator::Now().GetSeconds() > 0) // avoid startup events
    {
        GlobalRouteManager::DeleteGlobalRoutes();
        GlobalRouteManager::BuildGlobalRoutingDatabase();
        GlobalRouteManager::InitializeRoutes();
    }
}

void
Ipv4GlobalRouting::SetIpv4(Ptr<Ipv4> ipv4)
{
    NS_LOG_FUNCTION(this << ipv4);
    NS_ASSERT(!m_ipv4 && ipv4);
    m_ipv4 = ipv4;
}

} // namespace ns3
